@banach-space
Contributor

@banach-space banach-space commented Jan 2, 2025

This patch enforces a restriction in the Vector dialect: the non-indexed
operands of vector.insert and vector.extract must no longer be 0-D
vectors. In other words, rank-0 vector types like vector<f32> are
disallowed as the source or result.

EXAMPLES

The following are now illegal (note the use of vector<f32>):

%0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32>
%1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32>

Instead, use scalars as the source and result types:

  %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32>
  %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>

This change serves three goals, summarised below.

1. REDUCED AMBIGUITY

By enforcing scalar-only semantics when the result (vector.extract)
or source (vector.insert) is rank-0, we eliminate ambiguity
in interpretation. Prior to this patch, both f32 and vector<f32>
were accepted.
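
As a rough illustration of the rule, the check can be modelled outside MLIR as a predicate over a toy type description. This is only a sketch: `ToyType` and `isLegalNonIndexedType` are hypothetical names invented for this example and are not part of the MLIR API; the real checks are the ones added to the op verifiers in VectorOps.cpp.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an MLIR type: either a scalar element type
// (isVector == false) or a vector with the given shape. A vector with an
// empty shape models a rank-0 vector such as vector<f32>.
struct ToyType {
  bool isVector;
  std::vector<int64_t> shape;
};

// Models the restriction: the non-indexed operand (the value to store for
// vector.insert, the result for vector.extract) may be a scalar or a vector
// of rank >= 1, but not a rank-0 vector.
bool isLegalNonIndexedType(const ToyType &ty) {
  return !(ty.isVector && ty.shape.empty());
}
```

Under this model, `ToyType{false, {}}` (standing in for f32) and `ToyType{true, {2}}` (standing in for vector<2xf32>) are accepted, while `ToyType{true, {}}` (standing in for vector<f32>) is rejected.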

2. MATCH IMPLEMENTATION TO DOCUMENTATION

The current behaviour contradicts the documented intent. For example,
the documentation for vector.extract states:

Degenerates to an element type if n-k is zero.

This patch enforces that intent in code.

3. ENSURE SYMMETRY BETWEEN INSERT AND EXTRACT

With the stricter semantics in place, it’s natural and consistent to
make vector.insert behave symmetrically to vector.extract, i.e.,
degenerate the source type to a scalar when n = 0.

NOTES FOR REVIEWERS

  1. Main change is in "VectorOps.cpp", where stricter type checks are
    implemented.
  2. Test updates in "invalid.mlir" and "ops.mlir" are minor cleanups to
    remove now-illegal examples.
  3. Lowering changes in "VectorToSCF.cpp" are the main trade-off: we now
    require an additional vector.extract when a preceding
    vector.transfer_read generates a rank-0 vector.
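
The trade-off in note 3 can be pictured with a small standalone model of which ops one unrolled step produces. This is a simplified sketch, not the code in VectorToSCF.cpp: `opsForUnrolledStep` is a hypothetical helper that only lists op names, and it assumes the rank-reduced read is always a `vector.transfer_read`.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Returns the sequence of op names that one unrolled step would emit in this
// simplified model, given the rank of the vector produced by the
// rank-reduced vector.transfer_read.
std::vector<std::string> opsForUnrolledStep(int64_t unpackedReadRank) {
  std::vector<std::string> ops = {"vector.transfer_read"};
  // A rank-0 read result can no longer be passed to vector.insert directly,
  // so a scalar is extracted from it first.
  if (unpackedReadRank == 0)
    ops.push_back("vector.extract");
  ops.push_back("vector.insert");
  return ops;
}
```

In this model, only the rank-0 case pays for the extra op; reads of rank 1 or higher are inserted directly, as before.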

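RELATED RFC

  * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect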

@banach-space
Contributor Author

While the discussion is ongoing, I am posting this as draft. Please comment either here or on Discourse.

@banach-space banach-space force-pushed the andrzej/restrict_vec_insert_extract branch from 0962edb to 473bd0f Compare January 6, 2025 09:01
@github-actions

github-actions bot commented Jan 6, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@banach-space banach-space force-pushed the andrzej/restrict_vec_insert_extract branch from 473bd0f to 9434337 Compare January 6, 2025 09:08
@banach-space banach-space force-pushed the andrzej/restrict_vec_insert_extract branch from 9434337 to eb62bf0 Compare February 8, 2025 15:24
@banach-space banach-space force-pushed the andrzej/restrict_vec_insert_extract branch from eb62bf0 to 4084038 Compare April 29, 2025 10:39
@banach-space banach-space force-pushed the andrzej/restrict_vec_insert_extract branch 2 times, most recently from c214bbf to 29e7bb6 Compare May 30, 2025 10:42
@banach-space banach-space force-pushed the andrzej/restrict_vec_insert_extract branch from 29e7bb6 to 9f9491a Compare June 12, 2025 15:46
@banach-space banach-space marked this pull request as ready for review June 12, 2025 16:41
@llvmbot
Member

llvmbot commented Jun 12, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

This patch restricts the use of vector.insert and vector.extract Ops in
the Vector dialect. Specifically:

  • The non-indexed operands for vector.insert and vector.extract
    must now be non-0-D vectors.

The following are now illegal. Note that the source and result types
(i.e. non-indexed args) are rank-0 vectors:

  %0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32>
  %1 = vector.extract %arg0[0, 0] : vector<f32> from vector<2x2xf32>

Instead, use scalars as the source and result types:

  %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32>
  %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>

Put differently, this PR removes the ambiguity when it comes to
non-indexed operands of vector.insert and vector.extract. By
requiring that only one form is used, it eliminates the flexibility of
allowing both, thereby simplifying the semantics.

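For more context, see the related RFC:

  * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
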
Full diff: https://github.com/llvm/llvm-project/pull/121458.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp (+33-6)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+10)
  • (modified) mlir/test/Dialect/Vector/invalid.mlir (+2-2)
  • (modified) mlir/test/Dialect/Vector/ops.mlir (+2-4)
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index cc5623068ab10..08f398a1c8ba6 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -1294,6 +1294,10 @@ struct UnrollTransferReadConversion
 
   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
   /// accesses, and broadcasts and transposes in permutation maps.
+  ///
+  /// When unpacking rank-1 vectors (i.e. when the target rank is 0), replaces
+  /// `vector.transfer_read` with either `memref.load` or `tensor.extract` (for
+  /// MemRef and Tensor source, respectively).
   LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                 PatternRewriter &rewriter) const override {
     if (xferOp.getVectorType().getRank() <= options.targetRank)
@@ -1324,6 +1328,8 @@ struct UnrollTransferReadConversion
     for (int64_t i = 0; i < dimSize; ++i) {
       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
 
+      // FIXME: Rename this lambda - it does much more than just
+      // in-bounds-check generation.
       vec = generateInBoundsCheck(
           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
           /*inBoundsCase=*/
@@ -1338,12 +1344,33 @@ struct UnrollTransferReadConversion
             insertionIndices.push_back(rewriter.getIndexAttr(i));
 
             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
-            auto newXferOp = b.create<vector::TransferReadOp>(
-                loc, newXferVecType, xferOp.getBase(), xferIndices,
-                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
-                xferOp.getPadding(), Value(), inBoundsAttr);
-            maybeAssignMask(b, xferOp, newXferOp, i);
-            return b.create<vector::InsertOp>(loc, newXferOp, vec,
+
+            // A value that's read after rank-reducing the original
+            // vector.transfer_read Op.
+            Value unpackedReadRes;
+            if (newXferVecType.getRank() != 0) {
+              // Unpacking Vector that's rank > 2
+              // (use vector.transfer_read to load a rank-reduced vector)
+              unpackedReadRes = b.create<vector::TransferReadOp>(
+                  loc, newXferVecType, xferOp.getBase(), xferIndices,
+                  AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
+                  xferOp.getPadding(), Value(), inBoundsAttr);
+              maybeAssignMask(b, xferOp,
+                              dyn_cast<vector::TransferReadOp>(
+                                  unpackedReadRes.getDefiningOp()),
+                              i);
+            } else {
+              // Unpacking Vector that's rank == 1
+              // (use memref.load/tensor.extract to load a scalar)
+              unpackedReadRes = dyn_cast<MemRefType>(xferOp.getBase().getType())
+                                    ? b.create<memref::LoadOp>(
+                                           loc, xferOp.getBase(), xferIndices)
+                                          .getResult()
+                                    : b.create<tensor::ExtractOp>(
+                                           loc, xferOp.getBase(), xferIndices)
+                                          .getResult();
+            }
+            return b.create<vector::InsertOp>(loc, unpackedReadRes, vec,
                                               insertionIndices);
           },
           /*outOfBoundsCase=*/
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 2a2357319bd23..dc4bcd9b6bd84 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1383,6 +1383,11 @@ bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
 }
 
 LogicalResult vector::ExtractOp::verify() {
+  if (auto resTy = dyn_cast<VectorType>(getResult().getType()))
+    if (resTy.getRank() == 0)
+      return emitError(
+          "expected a scalar instead of a 0-d vector as the result type");
+
   // Note: This check must come before getMixedPosition() to prevent a crash.
   auto dynamicMarkersCount =
       llvm::count_if(getStaticPosition(), ShapedType::isDynamic);
@@ -3122,6 +3127,11 @@ void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
 }
 
 LogicalResult InsertOp::verify() {
+  if (auto srcTy = dyn_cast<VectorType>(getValueToStoreType()))
+    if (srcTy.getRank() == 0)
+      return emitError(
+          "expected a scalar instead of a 0-d vector as the source operand");
+
   SmallVector<OpFoldResult> position = getMixedPosition();
   auto destVectorType = getDestVectorType();
   if (position.size() > static_cast<unsigned>(destVectorType.getRank()))
diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir
index 04810ed52584f..57ec12a8ccac1 100644
--- a/mlir/test/Dialect/Vector/invalid.mlir
+++ b/mlir/test/Dialect/Vector/invalid.mlir
@@ -260,8 +260,8 @@ func.func @insert_precise_position_overflow(%a: f32, %b: vector<4x8x16xf32>) {
 // -----
 
 func.func @insert_0d(%a: vector<f32>, %b: vector<4x8x16xf32>) {
-  // expected-error@+1 {{expected position attribute rank + source rank to match dest vector rank}}
-  %1 = vector.insert %a, %b[2, 6] : vector<f32> into vector<4x8x16xf32>
+  // expected-error@+1 {{expected a scalar instead of a 0-d vector as the source operand}}
+  %1 = vector.insert %a, %b[0, 0, 0] : vector<f32> into vector<4x8x16xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index f3220aed4360c..7d43f2a84dc77 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -298,12 +298,10 @@ func.func @insert_val_idx(%a: f32, %b: vector<16xf32>, %c: vector<8x16xf32>,
 }
 
 // CHECK-LABEL: @insert_0d
-func.func @insert_0d(%a: f32, %b: vector<f32>, %c: vector<2x3xf32>) -> (vector<f32>, vector<2x3xf32>) {
+func.func @insert_0d(%a: f32, %b: vector<f32>) -> vector<f32> {
   // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[] : f32 into vector<f32>
   %1 = vector.insert %a,  %b[] : f32 into vector<f32>
-  // CHECK-NEXT: vector.insert %{{.*}}, %{{.*}}[0, 1] : vector<f32> into vector<2x3xf32>
-  %2 = vector.insert %b,  %c[0, 1] : vector<f32> into vector<2x3xf32>
-  return %1, %2 : vector<f32>, vector<2x3xf32>
+  return %1 : vector<f32>
 }
 
 // CHECK-LABEL: @insert_poison_idx

@banach-space
Contributor Author

Based on the discussion in the ODM, marking this as ready to review:

…tract

This patch enforces a restriction in the Vector dialect: the non-indexed
operands of `vector.insert` and `vector.extract` must no longer be 0-D
vectors. In other words, rank-0 vector types like `vector<f32>` are
disallowed as the source or result.

EXAMPLES
--------
The following are now **illegal** (note the use of `vector<f32>`):

```mlir
%0 = vector.insert %v, %dst[0, 0] : vector<f32> into vector<2x2xf32>
%1 = vector.extract %src[0, 0] : vector<f32> from vector<2x2xf32>
```

Instead, use scalars as the source and result types:

```mlir
  %0 = vector.insert %v, %dst[0, 0] : f32 into vector<2x2xf32>
  %1 = vector.extract %src[0, 0] : f32 from vector<2x2xf32>
```

This change serves three goals:

1. REDUCED AMBIGUITY
--------------------
By enforcing scalar-only semantics when n-k = 0, we eliminate ambiguity
in interpretation. Prior to this patch, both `f32` and `vector<f32>`
were accepted in practice, though only scalars were intended.

2. MATCH IMPLEMENTATION TO DOCUMENTATION
----------------------------------------
The current behavior contradicts the documented intent. For example,
the documentation for `vector.extract` states:

> Degenerates to an element type if n-k is zero.

This patch enforces that intent in code.

3. ENSURE SYMMETRY BETWEEN INSERT AND EXTRACT
---------------------------------------------
With the stricter semantics in place, it’s natural and consistent to
make `vector.insert` behave symmetrically to `vector.extract`, i.e.,
degenerate the source type to a scalar when n = 0.

NOTES FOR REVIEWERS
-------------------
1. Main change is in "VectorOps.cpp", where stricter type checks are
   implemented.
2. Test updates in "invalid.mlir" and "ops.mlir" are minor cleanups to
   remove now-illegal examples.
3. Lowering changes in "VectorToSCF.cpp" are the main trade-off: we now
   avoid using `vector.transfer_read` for scalar loads and instead rely on
   `memref.load` / `tensor.extract`.

RELATED RFC
-----------
  * https://discourse.llvm.org/t/rfc-should-we-restrict-the-usage-of-0-d-vectors-in-the-vector-dialect
@banach-space banach-space force-pushed the andrzej/restrict_vec_insert_extract branch from 9f9491a to efc29a7 Compare June 15, 2025 10:55
@banach-space banach-space requested a review from kuhar as a code owner June 15, 2025 10:55
@banach-space
Contributor Author

@dcaballe, @Groverkss, this is the patch that we discussed in the ODM last week. Would you be able to take a look sometime soon? It would be great to make some progress on this.

Comment on lines 889 to 892
Inserts an n-D source vector (the value to store) into an (n + k)-D
destination vector at a specified k-D position. When n = 0, the source
degenerates to a scalar element inserted into the (0 + k)-D destination
vector.
Member

It's a bit confusing when I read it together with the vector.extract docs.

Can we do:
n-D base vector (source for vector.extract, dest for vector.insert)
k-D position
(n-k)-D subvector, degenerates to a scalar if k = n

It's a bit easier to follow that way.

Contributor Author

Sure. But let me use the naming scheme from #131602, so:

  • valueToStore + dest for vector.insert,
  • source for vector.extract.

Let me know what you think!

Member

I don't think this was addressed. What I meant was to use the same rank for the same class of operands:

n-D vector --> source/dest
k-D position
(n-k)-D subvector (valueToStore, result vector), degenerates to a scalar if k = n.

I don't mind the naming scheme, but having consistent rank documentation is easier to read.
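
The rank convention proposed here (n-D base vector, k-D position, (n-k)-D subvector) can be written down as a small standalone computation. This is a sketch only: `ToySubvector` and `subvectorFor` are hypothetical names for illustration and do not exist in MLIR.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical description of the non-indexed operand: `isScalar` is true
// when the (n-k)-D subvector degenerates to a scalar (k == n); otherwise
// `shape` holds the trailing n-k dimensions of the base vector.
struct ToySubvector {
  bool isScalar;
  std::vector<int64_t> shape;
};

// Given the shape of the n-D base vector (source for vector.extract, dest
// for vector.insert) and the number of indices k in the position, returns
// the (n-k)-D subvector. Throws if the position has more indices than the
// base vector has dimensions.
ToySubvector subvectorFor(const std::vector<int64_t> &baseShape,
                          std::size_t k) {
  const std::size_t n = baseShape.size();
  if (k > n)
    throw std::invalid_argument("position rank exceeds base vector rank");
  if (k == n)
    return ToySubvector{true, {}};
  auto first = baseShape.begin() + static_cast<std::ptrdiff_t>(k);
  return ToySubvector{false, std::vector<int64_t>(first, baseShape.end())};
}
```

For a base shape of 4x8x16, a one-index position gives an 8x16 subvector and a three-index position gives a scalar. A rank-0 base with an empty position (n = k = 0) also gives a scalar under this convention.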

Contributor Author

Thanks for clarifying, now I see what you meant. Could you check the latest revision?

Contributor Author

Sorry, that was a Git failure on my part 🤦🏻

Could you check this commit that I've just pushed?

Comment on lines 1347 to 1373

// A value that's read after rank-reducing the original
// vector.transfer_read Op.
Value unpackedReadRes;
if (newXferVecType.getRank() != 0) {
  // Unpacking Vector that's rank > 2
  // (use vector.transfer_read to load a rank-reduced vector)
  unpackedReadRes = b.create<vector::TransferReadOp>(
      loc, newXferVecType, xferOp.getBase(), xferIndices,
      AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
      xferOp.getPadding(), Value(), inBoundsAttr);
  maybeAssignMask(b, xferOp,
                  dyn_cast<vector::TransferReadOp>(
                      unpackedReadRes.getDefiningOp()),
                  i);
} else {
  // Unpacking Vector that's rank == 1
  // (use memref.load/tensor.extract to load a scalar)
  unpackedReadRes = dyn_cast<MemRefType>(xferOp.getBase().getType())
                        ? b.create<memref::LoadOp>(
                               loc, xferOp.getBase(), xferIndices)
                              .getResult()
                        : b.create<tensor::ExtractOp>(
                               loc, xferOp.getBase(), xferIndices)
                              .getResult();
}
return b.create<vector::InsertOp>(loc, unpackedReadRes, vec,

Member

This is unrelated to the patch and changing behavior of other transformations. For now, if the transfer_read returns a 0-D vector, we should extract a scalar and then insert it.

Contributor Author

That's a great point, let me update this, thanks!

@joker-eph joker-eph changed the title [mlir][vector] Restrict vector.insert/vector.extract [mlir][vector] Restrict vector.insert/vector.extract to disallow 0-d vectors Jun 20, 2025
@banach-space
Contributor Author

@joker-eph, thanks for updating the title! I just wanted to point out that only the non-indexed arguments are disallowed from being rank-0. This change still allows the indexed arguments to be rank-0. This is explained in the summary.

This has been a very long discussion, hence posting to avoid any potential confusion.

@joker-eph
Collaborator

Yes, that seemed clear from the description. The title should be as descriptive as possible while remaining short; it's up to you if you want to add more information there.

Contributor

@dcaballe dcaballe left a comment

LGTM. Thanks for moving this forward! Please wait for one more approval before landing.

@banach-space banach-space requested a review from Groverkss June 23, 2025 15:14
…ctor.insert/vector.extract

Update the docs as suggested by Kunwar
Member

@Groverkss Groverkss left a comment

Nice! LGTM

@banach-space banach-space merged commit e9e25f0 into llvm:main Jun 26, 2025
7 checks passed
@banach-space banach-space deleted the andrzej/restrict_vec_insert_extract branch June 26, 2025 08:48
anthonyhatran pushed a commit to anthonyhatran/llvm-project that referenced this pull request Jun 26, 2025
…vectors (llvm#121458)
